Tensorflow实现卷积神经网络分类Cifar |
您所在的位置:网站首页 › cifar 10数据集分类 › Tensorflow实现卷积神经网络分类Cifar |
数据准备 去官网下载好这些数据。然后解压 导入相关模块 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np import pickle %matplotlib inline写一个类,处理数据。这个类实现了独热标签、返回指定数量训练集、返回测试集 class Cifar: def __init__(self, onehot = False): self.debug = False self.onehot= onehot self.CIFAR = './CIFAR/10/' self.current_batch_lib = 1 self.current_batch_idx_start = 0 self.label_name = np.load(self.CIFAR + "batches.meta")['label_names'] self.label_name_cn=['灰机', '汽车', '鸟', '喵', '长脖怪', '大黄', '蛤膜', '马', '船', '卡车'] self.batch_data, self.batch_label = self.__get_batch(self.current_batch_lib) #获取标签对应英文名称 #参数:标签序号 #返回:英文名 def label_name(self, label): return self.label_name[label] #私有方法,获取一个文件的batch #参数第几个batch文件 #返回RGB通道图像和标签 def __get_batch(self, idx): data, label = np.array([]), np.array([]) file = self.CIFAR + 'data_batch_' + str(idx) if idx == 'test_batch': file = self.CIFAR + 'test_batch' if self.debug: print('unpickleing ', file) with open(file, 'rb') as fo: dic = pickle.load(fo,encoding='bytes') fo.close() data = np.array(dic[b'data']) / 255 label= dic[b'labels'] return data.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1), np.array(label) #私有方法,label转独热码 #参数:标签 #返回独热码 def __label_onehot(self, labels): hot = np.zeros((len(labels), 10)) index_offset = np.arange(len(labels)) * 10 hot.flat[index_offset + labels.ravel()] = 1 return hot #获取size大小的训练集 #参数:训练集大小 #返回:(RGB图像,标签)。注:当前batch不够时,返回的训练集数量可能小于size def get_next_batch(self, size = 50): if self.current_batch_idx_start >= len(self.batch_label): #当前batch不够用 self.current_batch_lib = self.current_batch_lib + 1 if self.current_batch_lib > 5: #5个文件全搞过了,返回第1个文件重新来 self.current_batch_lib -= 5 self.current_batch_idx_start = 0 self.batch_data, self.batch_label = self.__get_batch(self.current_batch_lib) if self.debug: print('getting data in %d lib, offset: %d' % (self.current_batch_lib, self.current_batch_idx_start)) if self.onehot: #是否独热 ret = self.batch_data[self.current_batch_idx_start:self.current_batch_idx_start + size], \ self.__label_onehot(self.batch_label[self.current_batch_idx_start:self.current_batch_idx_start + size]) else: ret = self.batch_data[self.current_batch_idx_start:self.current_batch_idx_start + size], \ self.batch_label[self.current_batch_idx_start:self.current_batch_idx_start + size] self.current_batch_idx_start += size #更新偏移 return ret def get_test_batch(self): test_batch = self.__get_batch('test_batch') if self.onehot: return test_batch[0], self.__label_onehot(test_batch[1]) return test_batch定义几个创建tf矩阵和执行卷积、池化运算的函数 #创建tf权重矩阵 def get_weight_var(shape, stddev): return tf.Variable(tf.truncated_normal(shape, stddev=stddev)) #创建tf偏置项 def get_bias_var(shape): return tf.Variable(tf.constant(0.1), shape) #执行卷积运算 def conv2d(tensor, weight): return tf.nn.conv2d(tensor, weight, strides=[1, 1, 1, 1], padding="SAME") #最大池化 def max_pool(tensor): return tf.nn.max_pool(tensor, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="SAME")搭建一个两个卷积层、两个全连接层、带有dropout的卷积神经网络 X = tf.placeholder(tf.float32, [None, 32, 32, 3]) Y = tf.placeholder(tf.float32, [None, 10]) #第一个卷积层 weight1 = get_weight_var([5, 5, 3, 64], 0.05) bias1 = get_bias_var([64]) conv1 = conv2d(X, weight1) + bias1 pool1 = max_pool(conv1) active1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9, beta=0.75) #第二个卷积层 weight2 = get_weight_var([5, 5, 64, 64], 0.05) bias2 = get_bias_var([64]) conv2 = conv2d(pool1, weight2) + bias2 active2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.01 / 9, beta=0.75) pool2 = max_pool(active2) #第三个全连接层 keep_prob = tf.placeholder("float") input3 = tf.reshape(pool2, [-1, 8*8*64]) weight3 = get_weight_var([8*8*64, 128], 0.04) bias3 = get_bias_var([128]) output3 = tf.matmul(input3, weight3) + bias3 tf.nn.dropout(output3, keep_prob) active3 = tf.nn.relu(output3) #第四个全连接层 weight4 = get_weight_var([128, 10], 0.1) bias4 = get_bias_var([10]) output4 = tf.matmul(active3, weight4) + bias4 y_ = tf.nn.softmax(output4) #损失函数 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = Y, logits = y_)) train_step = tf.train.AdamOptimizer(1e-4).minimize(loss) #准确率 correct_prediction = tf.equal(tf.argmax(y_,1), tf.argmax(Y,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) #去除注释,运行代码块,sess将被重载!提前保存好 sess = tf.Session() sess.run(tf.initialize_all_variables()) cifar = Cifar(onehot=True)开始训练! #训练 for i in range(10000): batch_xs, batch_ys = cifar.get_next_batch(50) sess.run(train_step, feed_dict={ X:batch_xs, Y:batch_ys, keep_prob:0.7 }) if i % 100 == 0: print("iter:%d\tloss:%g accuracy:%g" % (i, \ sess.run(loss, feed_dict={ X:batch_xs, Y:batch_ys, keep_prob:1 }), \ sess.run(accuracy, feed_dict={ X:batch_xs, Y:batch_ys, keep_prob:1 })))这里训练次数越大越好。我在垃圾服务器上运行了30分钟才迭代完3000次。。 看一看我们的CNN是什么效果 #输出一个样本 batch_xs, batch_ys = cifar.get_next_batch(1) plt.figure("show pickle", figsize=(5,5)) plt.subplot(131) plt.imshow(batch_xs[0]) plt.axis('off') plt.subplot(132) plt.imshow(sess.run(conv1, feed_dict={ X:batch_xs, Y:batch_ys })[0].T[2].T) plt.axis('off') plt.subplot(133) plt.imshow(sess.run(conv2, feed_dict={ X:batch_xs, Y:batch_ys })[0].T[2].T) plt.axis('off') plt.show() y_actual = np.where(batch_ys[0] == 1)[0][0] y_cnn = sess.run(tf.argmax(y_, 1), feed_dict={ X:batch_xs, Y:batch_ys, keep_prob:1.0 })[0] print("人工智障寻思这好像是【", cifar.label_name_cn[y_cnn], "】") print("这玩意其实是个【", cifar.label_name_cn[y_actual], "】") if y_actual == y_cnn: print("\n还真叫它给猜对了\n") else: print("\n这个傻玩意,一会再让他学习10万次\n") pred_p = sess.run(y_, feed_dict={ X:batch_xs, Y:batch_ys, keep_prob:1.0 })[0] pred_p_cn = [cifar.label_name_cn[x] for x in pred_p.argsort()] pred_p_cn.reverse() print("CNN认为概率由高到低可能是:\n", pred_p_cn)保存我们的模型 #保存模型 !!!!会覆盖之前保存的,谨慎使用 if input("Enter YES to save") == "YES": saver=tf.train.Saver(max_to_keep=1) saver.save(sess,'./model/10.ckpt') else: print('canceled')以后可以用下面的代码恢复模型 #恢复模型!!!!谨慎使用 if input("Enter YES to restore") == "YES": saver=tf.train.Saver(max_to_keep=1) model_file=tf.train.latest_checkpoint('./model/') saver.restore(sess,model_file) else: print('canceled')下面的代码可以测试模型在测试集上的准确率 test_xs, test_ys = cifar.get_test_batch() print(sess.run(accuracy, feed_dict = { X:test_xs, Y:test_ys, keep_prob:1.0 })) |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |